{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "from scipy.misc import imsave\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "prob_files=glob('/data/DW/ordinary/DRN_RCN/' +\n",
    "                'DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/*/prob/*.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_senext50_batch16/prob/DFC2020_tes_unet_senext_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_resnext50_batch16/prob/DFC2020_tes_unet_resnext_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_resnet50_batch32/prob/DFC2020_tes_unet_resnet_32_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_resnet50_batch16/prob/DFC2020_tes_unet_resnet_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_senet50_batch16/prob/DFC2020_tes_unet_senet_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_cbamnext50_batch16/prob/DFC2020_tes_unet_cbamnext_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_senext50_batch32/prob/DFC2020_tes_unet_senext_32_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_resnext50_batch32/prob/DFC2020_tes_unet_resnext_32_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_resnet101_batch16/prob/DFC2020_tes_unet_resnet_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_senet50_batch32/prob/DFC2020_tes_unet_senet_32_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_cbamnet50_batch16/prob/DFC2020_tes_unet_cbamnet_16_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_cbamnet50_batch32/prob/DFC2020_tes_unet_cbamnet_32_.npy',\n",
       " '/data/DW/ordinary/DRN_RCN/DataFusion2020_TRAD_SEMISUPERVISED/pretes_track2/unet_cbamnext50_batch32/prob/DFC2020_tes_unet_cbamnext_32_.npy']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prob_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [11:51<00:00, 59.21s/it]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(len(prob_files))):\n",
    "    if i==0:\n",
    "        prob=np.load(prob_files[i])\n",
    "    else:\n",
    "        prob+=np.load(prob_files[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "prob=prob/13.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mask=prob>0.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5128, 8, 256, 256)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mask=mask.astype('uint8')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "filename"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5128/5128 [00:00<00:00, 248594.44it/s]\n"
     ]
    }
   ],
   "source": [
    "def load_tesdata(base_dir):\n",
    "\n",
    "    s1_dir=glob(base_dir+'s1_0'+'/'+'*.tif')\n",
    "    s2_dir=[]\n",
    "    lc_dir=[]\n",
    "    for i in tqdm(range(len(s1_dir))):\n",
    "        s1_filename=os.path.basename(s1_dir[i])\n",
    "        former=s1_filename.split('_s1')[0]\n",
    "        ID=s1_filename.split('_s1')[1]\n",
    "        s2_dir.append(base_dir+'s2_0/'+former+'_s2'+ID)\n",
    "        lc_dir.append(base_dir + 'lc_0/' + former + '_lc' + ID)\n",
    "\n",
    "    s1_dir=np.array(s1_dir)\n",
    "    s2_dir=np.array(s2_dir)\n",
    "    lc_dir=np.array(lc_dir)\n",
    "\n",
    "    return s1_dir,s2_dir,lc_dir\n",
    "\n",
    "base_dir = '/data/PublicData/DF2020/test_track1/'\n",
    "s1_dir,s2_dir,lc_dir=load_tesdata(base_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "pre_hex_color_dict={10:'000000',0:'009900',1:'c6b044',2:'fbff13',\n",
    "                    3:'b6ff05',4:'27ff87',5:'c24f44',\n",
    "                    6:'a5a5a5',7:'000000',8:'f9ffa4',9:'1c0dff'}\n",
    "\n",
    "def Hex_to_RGB(str):\n",
    "    r = int(str[0:2],16)\n",
    "    g = int(str[2:4],16)\n",
    "    b = int(str[4:6],16)\n",
    "    return [r,g,b]\n",
    "\n",
    "def DrawResult(labels, row, col):\n",
    "    num_class = 10\n",
    "\n",
    "    X_result = np.zeros((labels.shape[0], 3))\n",
    "    for i in range(num_class):\n",
    "        X_result[np.where(labels == i), 0] = Hex_to_RGB(pre_hex_color_dict[i])[0]\n",
    "        X_result[np.where(labels == i), 1] = Hex_to_RGB(pre_hex_color_dict[i])[1]\n",
    "        X_result[np.where(labels == i), 2] = Hex_to_RGB(pre_hex_color_dict[i])[2]\n",
    "\n",
    "    X_result = np.reshape(X_result, (row, col, 3))\n",
    "\n",
    "    return X_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net2templab={0:0,\n",
    "             1:1,\n",
    "             2:3,\n",
    "             3:4,\n",
    "             4:5,\n",
    "             5:6,\n",
    "             6:8,\n",
    "             7:9}\n",
    "outputdir='/data/PublicData/DF2020/test_track1/lc_fake/'\n",
    "visdir='/data/PublicData/DF2020/test_track1/lc_fake/vis/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/5128 [00:00<?, ?it/s]/home/sigma_wd/anaconda3/envs/pytorch/lib/python3.5/site-packages/ipykernel_launcher.py:35: DeprecationWarning: `imsave` is deprecated!\n",
      "`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
      "Use ``imageio.imwrite`` instead.\n",
      "100%|██████████| 5128/5128 [03:47<00:00, 22.61it/s]\n"
     ]
    }
   ],
   "source": [
    "for i in tqdm(range(5128)):\n",
    "    \n",
    "    tmp=mask[i]#8,256,256\n",
    "    \n",
    "    summ=np.sum(tmp,axis=0).squeeze()#may be have bg\n",
    "    \n",
    "    tmp_lab=np.argmax(tmp,axis=0).squeeze()\n",
    "    \n",
    "    # mapping\n",
    "    \n",
    "    h,w=tmp_lab.shape\n",
    "\n",
    "    tmp_lab=tmp_lab.reshape(-1)\n",
    "\n",
    "    label = list(map(lambda x: net2templab[x], tmp_lab))\n",
    "\n",
    "    label = np.array(label)  # list->array\n",
    "\n",
    "    label= label.reshape(h, w)\n",
    "    \n",
    "    # mask bg\n",
    "    \n",
    "    label[np.where(summ==0)]=7\n",
    "    \n",
    "    # image\n",
    "    \n",
    "    im = np.uint8(label + 1)\n",
    "    \n",
    "    # save\n",
    "    \n",
    "    filename=os.path.basename(lc_dir[i])\n",
    "    former = filename.split('lc')[0]\n",
    "    latter = filename.split('lc')[1]\n",
    "\n",
    "    imsave(outputdir + former + 'lc' + latter, im)\n",
    "    im_rgb = Image.fromarray(np.uint8(DrawResult(label.reshape(-1), 256, 256)))\n",
    "    im_rgb.save(visdir + former + 'lc' + latter[:-4] + '_vis.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (py4torch_tf)",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
